from __future__ import print_function

import sys
import argparse
import time
import math

import torch
import torch.backends.cudnn as cudnn
import os

from main_ce import set_loader
from util import AverageMeter
from util import adjust_learning_rate, warmup_learning_rate, accuracy
from util import set_optimizer
from networks.resnet_big import SupConResNet, LinearClassifier, SimSiam, BarlowTwinsModel, DirectDLRModel
from collections import defaultdict
import numpy as np
import random
import re
import wandb
import wilds
from wilds.datasets.synthetic import SpuriousCIFAR10, FeatureDataset, FeatureDataset_Cifar10
from torch.utils.data import DataLoader
from util import plot_singular_values_labels, compute_gram_matrix
from util import spectral_filter_and_normalize, truncate_by_singular_values
from util import plot_metric_vs_sv
from collections import defaultdict
try:
    import apex
    from apex import amp, optimizers
except ImportError:
    pass
from util import get_entropy_energy_based_rank
from torchvision import transforms, datasets
from PIL import Image

import matplotlib.pyplot as plt 


def parse_option():
    """Parse the command line for linear-classifier training.

    On top of the raw argparse values the returned namespace carries a few
    derived attributes:

    * ``data_folder``: fixed to ``'./datasets/'``.
    * ``lr_decay_epochs``: the comma-separated string converted to a list
      of ints.
    * ``model_name``: a descriptive run name (suffixed with ``_warm`` when
      ``--warm`` is given).
    * ``warmup_from``, ``warm_epochs``, ``warmup_to``: only set when
      ``--warm`` is given.
    * ``n_cls``: number of classes of the selected dataset.

    Returns:
        argparse.Namespace: the parsed and post-processed options.

    Raises:
        ValueError: if no class count is known for ``opt.dataset``.
    """
    parser = argparse.ArgumentParser('argument for training')

    parser.add_argument('--print_freq', type=int, default=10,
                        help='print frequency')
    parser.add_argument('--save_freq', type=int, default=50,
                        help='save frequency')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='batch_size')
    parser.add_argument('--num_workers', type=int, default=32,
                        help='num of workers to use')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of training epochs')

    # optimization
    parser.add_argument('--learning_rate', type=float, default=1.0,
                        help='learning rate')
    parser.add_argument('--lr_decay_epochs', type=str, default='60, 75, 90',
                        help='where to decay lr, can be a list')
    parser.add_argument('--lr_decay_rate', type=float, default=0.2,
                        help='decay rate for learning rate')
    parser.add_argument('--weight_decay', type=float, default=0,
                        help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum')

    # method ('DirectDLR' is listed because set_model has a branch for it)
    parser.add_argument('--method', type=str, default='SimCLR',
                        choices=['SupCon', 'SimCLR', 'SimSiam', 'BarlowTwins', 'DirectDLR'],
                        help='choose method')

    # model dataset
    parser.add_argument('--model', type=str, default='resnet18')
    parser.add_argument('--head', type=str, default='mlp', choices=['mlp', 'fixed'])
    parser.add_argument('--kappa', type=float, default=1.0)

    parser.add_argument('--dataset', type=str, default='spur_cifar10',
                        choices=['cifar10', 'cifar100', 'spur_cifar10'], help='dataset')

    # other setting
    parser.add_argument('--cosine', action='store_true',
                        help='using cosine annealing')
    parser.add_argument('--warm', action='store_true',
                        help='warm-up for large batch training')
    parser.add_argument('--trial', type=str, default='0',
                        help='id for recording multiple runs')
    parser.add_argument('--seed', type=int, default=0)

    parser.add_argument('--ckpt', type=str, default='save/SimCLR/spur_cifar10_models/SimCLR_spur_cifar10_temp_0.5_trial_27_0_0.0_0.0_0.0_False/last.pth',
                        help='path to pre-trained model')
    parser.add_argument('--use_wandb', action='store_true')
    parser.add_argument('--wandb_name', default='spur-ssl-project')
    parser.add_argument('--entity', default='spur-ssl')
    parser.add_argument('--augmented_features', action='store_true')
    parser.add_argument('--train_set_linear_layer', type=str, default='ds_train', choices=['val', 'train', 'balanced_train', 'ds_train', 'us_train'])
    parser.add_argument('--plot_path', type=str, default='Test/eigenvalues_labels_plot',
                        help='path to save the plots')
    parser.add_argument('--energy_threshold', type=float, default=0.9)
    parser.add_argument('--rank_threshold', type=float, default=0.1)
    parser.add_argument('--spur_str', type=float, default=0.95)
    parser.add_argument('--num_zero_high', type=int, default=0)
    parser.add_argument('--num_zero_low', type=int, default=0)

    opt = parser.parse_args()

    # set the path according to the environment
    opt.data_folder = './datasets/'

    # int() tolerates the surrounding blanks of the default '60, 75, 90'.
    opt.lr_decay_epochs = [int(it) for it in opt.lr_decay_epochs.split(',')]

    # model_name must exist before the warm-up branch appends to it;
    # previously it was read without ever being assigned.
    opt.model_name = '{}_{}_lr_{}_decay_{}_bsz_{}'.format(
        opt.dataset, opt.model, opt.learning_rate, opt.weight_decay,
        opt.batch_size)

    # warm-up for large-batch training,
    if opt.warm:
        opt.model_name = '{}_warm'.format(opt.model_name)
        opt.warmup_from = 0.01
        opt.warm_epochs = 10
        if opt.cosine:
            eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
            opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
                    1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
        else:
            opt.warmup_to = opt.learning_rate

    # Class counts per dataset. Some entries are not selectable through
    # --dataset today; they are kept so the table matches the original
    # branches once the choices are extended.
    n_cls_by_dataset = {
        'cifar10': 10,
        'spur_cifar10': 10,
        'cifar100': 100,
        'waterbirds': 2,
        'cmnist': 2,
        'metashift': 2,
        'celebA': 2,
    }
    if opt.dataset not in n_cls_by_dataset:
        raise ValueError('dataset not supported: {}'.format(opt.dataset))
    opt.n_cls = n_cls_by_dataset[opt.dataset]

    return opt


def set_model(opt):
    """Build the encoder model, the linear classifier and the loss.

    The encoder class is chosen from ``opt.method``. When ``opt.ckpt`` is
    set, its ``'model'`` state dict is loaded into the encoder model;
    otherwise the randomly initialised encoder parameters are frozen.
    Everything is then moved to the GPU.

    Args:
        opt: option namespace; reads ``method``, ``model``, ``head``,
            ``kappa``, ``n_cls`` and ``ckpt``.

    Returns:
        tuple: ``(model, classifier, criterion)``, all on the GPU.

    Raises:
        ValueError: if ``opt.method`` is not a supported method.
        NotImplementedError: if no CUDA device is available.
    """
    # Initialize model based on selected contrastive method.
    # 'SupCon' is an accepted --method value; it is mapped to the same
    # network class as 'SimCLR'.
    # NOTE(review): confirm SupCon checkpoints were trained with SupConResNet.
    if opt.method in ('SupCon', 'SimCLR'):
        model = SupConResNet(name=opt.model, head=opt.head, k=opt.kappa)
    elif opt.method == 'SimSiam':
        model = SimSiam(name=opt.model)
    elif opt.method == 'BarlowTwins':
        model = BarlowTwinsModel(name=opt.model)
    elif opt.method == 'DirectDLR':
        model = DirectDLRModel(name=opt.model)
    else:
        raise ValueError(f'Contrastive method not supported: {opt.method}')

    # Initialize classifier and loss function
    classifier = LinearClassifier(name=opt.model, num_classes=opt.n_cls)
    criterion = torch.nn.CrossEntropyLoss()

    if not torch.cuda.is_available():
        raise NotImplementedError('This code requires GPU')

    # Load the weights BEFORE any DataParallel wrapping: the checkpoint keys
    # have their "module." prefix stripped, so they only match the unwrapped
    # model (DataParallel would rename the keys to "encoder.module.*").
    if opt.ckpt:
        print(f'[INFO] Loading checkpoint from {opt.ckpt}')
        ckpt = torch.load(opt.ckpt, map_location='cpu')
        new_state_dict = {
            key.replace("module.", ""): value
            for key, value in ckpt['model'].items()
        }
        model.load_state_dict(new_state_dict)
    else:
        print('[INFO] No checkpoint provided. Using randomly initialized encoder.')
        print('[INFO] Freezing encoder parameters.')
        for param in model.encoder.parameters():
            param.requires_grad = False

    # Move to GPU
    if torch.cuda.device_count() > 1:
        model.encoder = torch.nn.DataParallel(model.encoder)
    model = model.cuda()
    classifier = classifier.cuda()
    criterion = criterion.cuda()
    # Only enable the cuDNN autotuner when the caller has not asked for
    # deterministic behaviour; otherwise it would undo that setup.
    if not cudnn.deterministic:
        cudnn.benchmark = True

    return model, classifier, criterion


def get_features_labels_metadata(model, data_loader):
    """Run the encoder over a loader and collect features, labels, metadata.

    The model is put in eval mode and every batch is encoded with
    ``model.encoder`` under ``torch.no_grad()``. Images are moved to the GPU
    when one is available; on a CPU-only machine they are used as-is.

    Args:
        model: object exposing ``eval()`` and an ``encoder`` callable.
        data_loader: iterable of batches where ``batch[0]`` are the images,
            ``batch[1]`` the labels and ``batch[2]`` the metadata tensors.

    Returns:
        tuple: ``(features, labels, metadata)`` CPU tensors, each the
        concatenation along dim 0 of the per-batch tensors.
    """
    model.eval()
    use_cuda = torch.cuda.is_available()

    all_features, all_labels, all_metadata = [], [], []
    with torch.no_grad():
        for data in data_loader:
            images, labels, metadata = data[0], data[1], data[2]
            # Only the images feed the encoder; labels and metadata are
            # stored on the CPU, so they never need a GPU round trip.
            if use_cuda:
                images = images.cuda(non_blocking=True)

            # pass images through the encoder
            features = model.encoder(images)
            all_features.append(features.cpu())
            all_labels.append(labels.cpu())
            all_metadata.append(metadata.cpu())

    features = torch.cat(all_features, dim=0)
    labels = torch.cat(all_labels, dim=0)
    metadata = torch.cat(all_metadata, dim=0)

    return features, labels, metadata

def get_features_labels(model, data_loader):
    """Encode every ``(images, labels)`` batch and stack the results.

    The model is switched to eval mode and ``model.encoder`` is applied to
    each image batch with gradients disabled. Batches are moved to the GPU
    when one is available.

    Args:
        model: object exposing ``eval()`` and an ``encoder`` callable.
        data_loader: iterable yielding ``(images, labels)`` pairs.

    Returns:
        tuple: ``(features, labels)`` CPU tensors concatenated along dim 0.
    """
    model.eval()

    feature_chunks = []
    label_chunks = []

    with torch.no_grad():
        for batch_images, batch_labels in data_loader:
            if torch.cuda.is_available():
                batch_images = batch_images.cuda(non_blocking=True)
                batch_labels = batch_labels.cuda(non_blocking=True)

            # encoder output is stored on the CPU right away
            encoded = model.encoder(batch_images)
            feature_chunks.append(encoded.cpu())
            label_chunks.append(batch_labels.cpu())

    stacked_features = torch.cat(feature_chunks, dim=0)
    stacked_labels = torch.cat(label_chunks, dim=0)
    return stacked_features, stacked_labels

def train(train_loader, feature_loader, model, classifier, criterion, optimizer, epoch, opt):
    """Train the linear classifier for one epoch on pre-extracted features.

    Args:
        train_loader: loader whose ``dataset.eval(preds, labels, metadata)``
            computes the group-wise results for the epoch.
        feature_loader: iterable of ``(features, labels, metadata)`` batches.
        model: encoder model; only switched to eval mode here, the forward
            pass uses the cached features.
        classifier: module trained on the features.
        criterion: loss applied to ``classifier(features)`` and the labels.
        optimizer: optimizer stepping the classifier parameters.
        epoch: current epoch index (not used inside this function).
        opt: option namespace; ``opt.n_cls`` caps the top-k accuracy.

    Returns:
        tuple: ``(losses.avg, top1.avg, res)`` where ``res`` is the first
        element returned by ``train_loader.dataset.eval``.
    """
    import gc

    model.eval()
    classifier.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    all_preds = []
    all_labels = []
    all_metadata = []

    use_cuda = torch.cuda.is_available()
    # top-5 is not defined for fewer than 5 classes, so cap k at n_cls
    topk = (1, min(5, opt.n_cls))

    end = time.time()
    for features, labels, metadata in feature_loader:
        data_time.update(time.time() - end)

        # metadata is only collected for the final evaluation, so it stays
        # on the CPU
        if use_cuda:
            features = features.cuda()
            labels = labels.cuda()
        bsz = labels.shape[0]

        outputs = classifier(features)
        loss = criterion(outputs, labels)
        _, preds = torch.max(outputs, 1)

        # update metric
        losses.update(loss.item(), bsz)
        acc1, _ = accuracy(outputs, labels, topk=topk)
        top1.update(acc1[0], bsz)

        all_preds.append(preds.cpu())
        all_labels.append(labels.cpu())
        all_metadata.append(metadata.cpu())

        # SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    all_preds = torch.cat(all_preds, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    all_metadata = torch.cat(all_metadata, dim=0)

    res, _ = train_loader.dataset.eval(all_preds, all_labels, all_metadata)

    # release the per-epoch buffers before the next epoch allocates its own
    del all_preds, all_labels, all_metadata
    gc.collect()
    torch.cuda.empty_cache()

    return losses.avg, top1.avg, res


def validate(val_loader, feature_loader, model, classifier, criterion, opt):
    """Evaluate the linear classifier on pre-extracted validation features.

    Args:
        val_loader: loader whose ``dataset.eval(preds, labels, metadata)``
            computes the group-wise results.
        feature_loader: iterable of ``(features, labels, metadata)`` batches.
        model: encoder model; only switched to eval mode here.
        classifier: module evaluated on the features.
        criterion: loss applied to ``classifier(features)`` and the labels.
        opt: option namespace; ``opt.n_cls`` caps the top-k accuracy.

    Returns:
        tuple: ``(losses.avg, top1.avg, res)`` where ``res`` is the first
        element returned by ``val_loader.dataset.eval``.
    """
    model.eval()
    classifier.eval()

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    all_preds = []
    all_labels = []
    all_metadata = []

    use_cuda = torch.cuda.is_available()
    # top-5 is not defined for fewer than 5 classes, so cap k at n_cls
    topk = (1, min(5, opt.n_cls))

    with torch.no_grad():
        end = time.time()
        for features, labels, metadata in feature_loader:
            # metadata is only collected for the final evaluation, so it
            # stays on the CPU
            if use_cuda:
                features = features.cuda()
                labels = labels.cuda()
            bsz = labels.shape[0]

            # forward
            outputs = classifier(features)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)

            # update metric
            losses.update(loss.item(), bsz)
            acc1, _ = accuracy(outputs, labels, topk=topk)
            top1.update(acc1[0], bsz)
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())
            all_metadata.append(metadata.cpu())

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

    all_preds = torch.cat(all_preds, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    all_metadata = torch.cat(all_metadata, dim=0)
    res, _ = val_loader.dataset.eval(all_preds, all_labels, all_metadata)

    return losses.avg, top1.avg, res


# def main(num_singular_values, uniform=False):
def main(opt, supcon_epoch):
    """Train and evaluate a linear classifier on frozen encoder features.

    Features for the train and validation loaders are extracted once with
    the encoder returned by ``set_model``. The classifier is then trained
    for ``opt.epochs`` epochs on the cached features and validated after
    every epoch. At the end, summary statistics are printed and, when
    ``opt.use_wandb`` is set, logged with ``wandb.log``.

    ``wandb.init`` is not called here.
    NOTE(review): the caller is presumably expected to have started the run.

    Args:
        opt: option namespace as produced by ``parse_option``.
        supcon_epoch: value passed as ``step`` to the final ``wandb.log``.

    Returns:
        tuple: ``(effective_rank, val_effective_rank, train_acc,
        train_wg_acc, train_bg_acc, val_acc, val_wg_acc, val_bg_acc)``;
        the accuracies are those of the last epoch.

    Raises:
        ValueError: if ``opt.epochs`` is smaller than 1.
    """
    import gc

    # Fail before the expensive feature extraction: with no epoch the
    # summary below has nothing to report.
    if opt.epochs < 1:
        raise ValueError('opt.epochs must be >= 1, got {}'.format(opt.epochs))

    best_val_acc = 0
    best_train_acc = 0
    best_val_wg_acc = 0
    best_train_wg_acc = 0
    best_val_bg_acc = 0
    best_train_bg_acc = 0
    accuracy_list = []

    # setup the seed
    torch.manual_seed(opt.seed)
    np.random.seed(opt.seed)
    random.seed(opt.seed)
    torch.cuda.manual_seed(opt.seed)

    # enable deterministic algorithms
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # build data loader
    train_loader, val_loader = set_loader(opt)

    # build model and criterion
    model, classifier, criterion = set_model(opt)

    # build optimizer
    optimizer = set_optimizer(opt, classifier)

    # extract the encoder features once; every epoch reuses them
    features, labels, metadata = get_features_labels_metadata(model, train_loader)
    feature_dataset = FeatureDataset(features, labels, metadata)
    feature_loader = DataLoader(feature_dataset, batch_size=opt.batch_size, shuffle=True)

    val_features, val_labels, val_metadata = get_features_labels_metadata(model, val_loader)
    val_feature_dataset = FeatureDataset(val_features, val_labels, val_metadata)
    val_feature_loader = DataLoader(val_feature_dataset, batch_size=opt.batch_size, shuffle=False)

    gc.collect()

    # training routine
    for epoch in range(1, opt.epochs + 1):

        adjust_learning_rate(opt, optimizer, epoch)

        # train for one epoch
        time1 = time.time()
        linear_train_loss, train_acc, results = train(train_loader, feature_loader, model, classifier, criterion,
                          optimizer, epoch, opt)
        time2 = time.time()

        train_wg_acc = results['acc_wg']*100
        train_bg_acc = results['best_acc']*100

        print('Train epoch {}, total time {:.2f}, loss {:.4f}, accuracy {:.2f}, wg accuracy {:.2f}, bg accuracy {:.2f}'.format(
            epoch, time2 - time1, linear_train_loss, train_acc, train_wg_acc, train_bg_acc))

        if train_acc > best_train_acc:
            best_train_acc = train_acc
            best_train_wg_acc = train_wg_acc
            best_train_bg_acc = train_bg_acc

        # eval for one epoch
        linear_val_loss, val_acc, val_results = validate(val_loader, val_feature_loader, model, classifier, criterion, opt)

        val_wg_acc = val_results['acc_wg']*100
        val_bg_acc = val_results['best_acc']*100

        print('Val epoch {}, loss {:.4f}, accuracy {:.2f}, wg accuracy {:.2f}, bg accuracy {:.2f}'.format(
            epoch, linear_val_loss, val_acc, val_wg_acc, val_bg_acc))
        accuracy_list.append((val_acc, val_wg_acc, val_bg_acc))

        # Best epoch is chosen by average accuracy; ties are broken by
        # worst-group accuracy, then by best-group accuracy.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_val_wg_acc = val_wg_acc
            best_val_bg_acc = val_bg_acc
        elif val_acc == best_val_acc:
            if val_wg_acc > best_val_wg_acc:
                best_val_wg_acc = val_wg_acc
            elif val_wg_acc == best_val_wg_acc:
                if val_bg_acc > best_val_bg_acc:
                    best_val_bg_acc = val_bg_acc

    last_acc, last_wg_acc, last_bg_acc = accuracy_list[-1]

    # Average over the last 10 epochs (or all of them when fewer were run)
    last_10_accuracies = accuracy_list[-10:]
    n_last = len(last_10_accuracies)
    avg_last_10_acc = sum(acc[0] for acc in last_10_accuracies) / n_last
    avg_last_10_wg_acc = sum(acc[1] for acc in last_10_accuracies) / n_last
    avg_last_10_bg_acc = sum(acc[2] for acc in last_10_accuracies) / n_last
    print(
        "Average of last 10 accuracies: {:.2f}, Average of last 10 worst-group accuracies: {:.2f}, Average of last 10 best-group accuracies: {:.2f}".format(
            avg_last_10_acc, avg_last_10_wg_acc, avg_last_10_bg_acc
        )
    )

    entropy, effective_rank, energy_based_rank = get_entropy_energy_based_rank(features, opt)
    val_entropy, val_effective_rank, val_energy_based_rank = get_entropy_energy_based_rank(val_features, opt)

    print(f"Train — Entropy: {entropy:.4f}, Effective Rank: {effective_rank:.2f}, Energy-Based Rank: {energy_based_rank:.2f}")
    print(f"Val   — Entropy: {val_entropy:.4f}, Effective Rank: {val_effective_rank:.2f}, Energy-Based Rank: {val_energy_based_rank:.2f}")

    if opt.use_wandb:
        wandb_dict = {
            'Linear train acc': best_train_acc,
            'Linear train worst-group acc': best_train_wg_acc,
            'Linear train best-group acc': best_train_bg_acc,
            'Linear val acc': best_val_acc,
            'Linear val worst-group acc': best_val_wg_acc,
            'Linear val best-group acc': best_val_bg_acc,
            'Train linear entropy': entropy,
            'Train linear effective rank': effective_rank,
            'Train linear energy-based rank': energy_based_rank,
            'Val linear entropy': val_entropy,
            'Val linear effective rank': val_effective_rank,
            'Val linear energy-based rank': val_energy_based_rank,
            'Last linear val acc': last_acc,
            'Last linear val worst-group acc': last_wg_acc,
            'Last linear val best-group acc': last_bg_acc,
            'Average over 10 last linear val acc': avg_last_10_acc,
            'Average over last 10 linear val worst-group acc': avg_last_10_wg_acc,
            'Average over last 10 linear val best-group acc': avg_last_10_bg_acc,
        }

        wandb.log(wandb_dict, step=supcon_epoch)

    print('best accuracy: {:.2f}'.format(best_val_acc), 'and worst-group accuracy: {:.2f}'.format(best_val_wg_acc),
          'and best-group accuracy: {:.2f}'.format(best_val_bg_acc))

    print('Last accuracy: {:.2f}, Last worst-group accuracy: {:.2f}, Last best-group accuracy: {:.2f}'.format(last_acc, last_wg_acc, last_bg_acc))
    print('Train entropy: {:.2f}, effective rank: {}, and energy-based rank: {}'.format(entropy, effective_rank, energy_based_rank))
    print('Val entropy: {:.2f}, effective rank: {}, and energy-based rank: {}'.format(val_entropy, val_effective_rank, val_energy_based_rank))
    print('Average last 10 accuracies: {:.2f}, Average last 10 worst-group accuracies: {:.2f}, Average last 10 best-group accuracies: {:.2f}'
          .format(avg_last_10_acc, avg_last_10_wg_acc, avg_last_10_bg_acc))
    return effective_rank, val_effective_rank, train_acc, train_wg_acc, train_bg_acc, val_acc, val_wg_acc, val_bg_acc

if __name__ == '__main__':
    # Script entry point: parse the CLI options here, since main() expects
    # them (and a logging step) as arguments.
    opt = parse_option()
    if opt.use_wandb:
        # main() logs with wandb.log, which needs an active run.
        wandb.init(project=opt.wandb_name, entity=opt.entity, config=vars(opt))
    # Standalone runs have no pre-training epoch to report against, so the
    # summary is logged at step 0.
    main(opt, supcon_epoch=0)

# if __name__ == '__main__':
    
#     u_train_ranks,  u_val_ranks  = [], []
#     u_train_accs,   u_val_accs   = [], []
#     u_train_wg_accs, u_val_wg_accs = [], []
#     u_train_bg_accs, u_val_bg_accs = [] , []

#     train_ranks,  val_ranks  = [], []
#     train_accs,   val_accs   = [], []
#     train_wg_accs, val_wg_accs = [], []
#     train_bg_accs, val_bg_accs = [] , []
#     d = 512
#     train_rank, val_rank, train_acc, train_wg_acc, train_bg_acc, val_acc, val_wg_acc, val_bg_acc = main(num_singular_values=0, uniform=False)
#     for num_singular_values in range(1, d+1):
#         print(f"\n[INFO] Running main() with num_singular_values = {num_singular_values}")
#         u_train_rank, u_val_rank, u_train_acc, u_train_wg_acc, u_train_bg_acc, u_val_acc, u_val_wg_acc, u_val_bg_acc = main(num_singular_values=num_singular_values, uniform=True)
        

#         # print(u_train_rank, u_val_rank)
#         # print(u_train_acc.cpu().numpy(), u_train_wg_acc, u_train_bg_acc)
#         # print(u_val_acc.cpu().numpy(), u_val_wg_acc, u_val_bg_acc)
#         # print(train_rank, val_rank)
#         # print(train_acc.cpu().numpy(), train_wg_acc, train_bg_acc)
#         # print(val_acc.cpu().numpy(), val_wg_acc, val_bg_acc)
        
#         u_train_ranks.append(u_train_rank)
#         u_val_ranks.append(u_val_rank)
#         u_train_accs.append(u_train_acc.cpu().numpy())
#         u_val_accs.append(u_val_acc.cpu().numpy())
#         u_train_wg_accs.append(u_train_wg_acc)
#         u_val_wg_accs.append(u_val_wg_acc)
#         u_train_bg_accs.append(u_train_bg_acc)
#         u_val_bg_accs.append(u_val_bg_acc)

#         train_ranks.append(train_rank)
#         val_ranks.append(val_rank)
#         train_accs.append(train_acc.cpu().numpy())
#         val_accs.append(val_acc.cpu().numpy())
#         train_wg_accs.append(train_wg_acc)
#         val_wg_accs.append(val_wg_acc)
#         train_bg_accs.append(train_bg_acc)
#         val_bg_accs.append(val_bg_acc)
    
    
#     x = np.arange(1, d + 1)

#     plot_metric_vs_sv(
#         x,
#         u_train_ranks, train_ranks,
#         ylabel='Effective Rank',
#         title='Train Effective Rank vs Singular Value Count',
#         filename='train_effective_rank_vs_singular_values.png'
#     )

#     plot_metric_vs_sv(
#         x,
#         u_val_ranks, val_ranks,
#         ylabel='Effective Rank',
#         title='Val Effective Rank vs Singular Value Count',
#         filename='val_effective_rank_vs_singular_values.png'
#     )

#     plot_metric_vs_sv(
#         x,
#         u_train_accs, train_accs,
#         u_best=u_train_bg_accs, u_worst=u_train_wg_accs,
#         best=train_bg_accs,     worst=train_wg_accs,
#         ylabel='Accuracy',
#         title='Train Accuracy vs Singular Value Count',
#         filename='train_accuracy_vs_singular_values.png'
#     )

#     plot_metric_vs_sv(
#         x,
#         u_val_accs, val_accs,
#         u_best=u_val_bg_accs, u_worst=u_val_wg_accs,
#         best=val_bg_accs,     worst=val_wg_accs,
#         ylabel='Accuracy',
#         title='Val Accuracy vs Singular Value Count',
#         filename='val_accuracy_vs_singular_values.png'
#     )

#     plot_metric_vs_sv(
#         x,
#         u_train_wg_accs, train_wg_accs,
#         ylabel='Worst-Group Accuracy',
#         title='Train Worst-Group Accuracy vs Singular Value Count',
#         filename='train_wg_accuracy_vs_singular_values.png'
#     )

#     plot_metric_vs_sv(
#         x,
#         u_val_wg_accs, val_wg_accs,
#         ylabel='Worst-Group Accuracy',
#         title='Val Worst-Group Accuracy vs Singular Value Count',
#         filename='val_wg_accuracy_vs_singular_values.png'
#     )


